In this notebook, we'll be building a generative adversarial network (GAN) trained on the MNIST dataset. From this, we'll be able to generate new handwritten digits!
GANs were first introduced in 2014 by Ian Goodfellow and others in Yoshua Bengio's lab. Since then, GANs have exploded in popularity.
The idea behind GANs is that you have two networks, a generator $G$ and a discriminator $D$, competing against each other. The generator makes "fake" data to pass to the discriminator. The discriminator also sees real training data and predicts if the data it's received is real or fake.
- The generator is trained to fool the discriminator: it wants to output data that looks as close as possible to real training data.
- The discriminator is a classifier that is trained to figure out which data is real and which is fake.
What ends up happening is that the generator learns to make data that is indistinguishable from real data to the discriminator.
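Formally, this competition is the minimax game from the original GAN paper: the discriminator $D$ tries to maximize the objective below, while the generator $G$ tries to minimize it,

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

where $z$ is a random latent vector and $G(z)$ is a generated image.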
The general structure of a GAN is shown in the diagram above, using MNIST images as data. The latent sample is a random vector that the generator uses to construct its fake images. This is often called a latent vector and that vector space is called latent space. As the generator trains, it figures out how to map latent vectors to recognizable images that can fool the discriminator.
If you're interested in generating only new images, you can throw out the discriminator after training. In this notebook, I'll show you how to define and train these adversarial networks in PyTorch and generate new images!
In [ ]:
%matplotlib inline
import numpy as np
import torch
import matplotlib.pyplot as plt
In [ ]:
from torchvision import datasets
import torchvision.transforms as transforms
# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 64
# convert data to torch.FloatTensor
transform = transforms.ToTensor()
# get the training datasets
train_data = datasets.MNIST(root='data', train=True,
download=True, transform=transform)
# prepare data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
num_workers=num_workers)
In [ ]:
# obtain one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)  # dataiter.next() was removed in newer Python/PyTorch
images = images.numpy()
# get one image from the batch
img = np.squeeze(images[0])
fig = plt.figure(figsize = (3,3))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
The discriminator network is going to be a pretty typical linear classifier. To make this network a universal function approximator, we'll need at least one hidden layer, and these hidden layers should have one key attribute:
All hidden layers will have a leaky ReLU activation function applied to their outputs.
We should use a leaky ReLU to allow gradients to flow backwards through the layer unimpeded. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values.
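To see the difference concretely, here's a quick check of `F.leaky_relu` with a negative slope of 0.2 (the slope we'll use in this notebook's layers):

In [ ]:
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
# negative inputs are scaled by 0.2 instead of being zeroed out,
# so a small gradient still flows for x < 0
print(F.leaky_relu(x, negative_slope=0.2))  # tensor([-0.4000, -0.1000,  0.0000,  1.0000])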
We'll also take the approach of using a more numerically stable loss function on the outputs. Recall that we want the discriminator to output a value between 0 and 1 indicating whether an image is real or fake. We will ultimately use `BCEWithLogitsLoss`, which combines a sigmoid activation function and binary cross entropy loss in one function. So, our final output layer should not have any activation function applied to it.
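As a quick sanity check, `BCEWithLogitsLoss` on raw logits gives the same value as applying a sigmoid followed by `BCELoss`, just computed in a more numerically stable way:

In [ ]:
import torch
import torch.nn as nn

logits = torch.tensor([0.5, -1.0, 2.0])   # raw, unactivated outputs
targets = torch.tensor([1.0, 0.0, 1.0])   # real/fake labels

# the two losses match (up to floating point error)
print(nn.BCEWithLogitsLoss()(logits, targets))
print(nn.BCELoss()(torch.sigmoid(logits), targets))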
In [1]:
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()
        # define all layers; hidden layers shrink from hidden_dim*4 to hidden_dim
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        # flatten image
        x = x.view(-1, 28*28)
        # apply leaky relu activation to all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # final layer outputs raw logits (BCEWithLogitsLoss adds the sigmoid)
        x = self.fc4(x)
        return x
The generator network will be almost exactly the same as the discriminator network, except that we're applying a $\tanh$ activation function to our output layer. The generator has been found to perform best with $\tanh$ on its output, which scales the output to be between -1 and 1, instead of 0 and 1.
Recall that we also want these outputs to be comparable to the real input pixel values, which are read in as normalized values between 0 and 1.
So, we'll also have to scale our real input images to have pixel values between -1 and 1 when we train the discriminator.
I'll do this in the training loop, later on.
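As a tiny preview of that rescaling step: `ToTensor()` gives pixel values in [0, 1], and multiplying by 2 and subtracting 1 maps them onto the generator's $\tanh$ output range:

In [ ]:
import torch

x = torch.tensor([0.0, 0.5, 1.0])  # pixel values as read in
print(x * 2 - 1)                   # tensor([-1., 0., 1.]) -- matches tanh's range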
In [ ]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()
        # define all layers; hidden layers grow from hidden_dim to hidden_dim*4
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        self.fc4 = nn.Linear(hidden_dim*4, output_size)
    def forward(self, x):
        # pass x through all layers, with leaky relu on the hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # final layer should have tanh applied, scaling output to [-1, 1]
        x = torch.tanh(self.fc4(x))
        return x
In [ ]:
# Discriminator hyperparams

# Size of input image to discriminator (28*28)
input_size = 784
# Size of discriminator output (real or fake)
d_output_size = 1
# Size of *last* hidden layer in the discriminator
d_hidden_size = 32

# Generator hyperparams

# Size of latent vector to give to generator
z_size = 100
# Size of generator output (generated image)
g_output_size = 784
# Size of *first* hidden layer in the generator
g_hidden_size = 32
In [ ]:
# instantiate discriminator and generator
D = Discriminator(input_size, d_hidden_size, d_output_size)
G = Generator(z_size, g_hidden_size, g_output_size)
# check that they are as you expect
print(D)
print()
print(G)
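As an optional sanity check (using the sizes defined above), we can push a dummy batch through both networks and confirm the output shapes:

In [ ]:
# a fake "batch" of 2 MNIST-shaped images and 2 latent vectors
dummy_images = torch.randn(2, 1, 28, 28)
dummy_z = torch.randn(2, z_size)
print(D(dummy_images).shape)  # expect torch.Size([2, 1]) -- one logit per image
print(G(dummy_z).shape)       # expect torch.Size([2, 784]) -- one flattened image each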
Now we need to calculate the losses.

- For the discriminator, the total loss is the sum of the losses for real and fake images: `d_loss = d_real_loss + d_fake_loss`.
- Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.

The losses will be binary cross entropy loss with logits, which we can get with `BCEWithLogitsLoss`. This combines a sigmoid activation function and binary cross entropy loss in one function.

For the real images, we want `D(real_images) = 1`. That is, we want the discriminator to classify the real images with a label of 1, indicating that these are real. To help the discriminator generalize better, the labels are reduced a bit from 1.0 to 0.9. For this, we'll use the parameter `smooth`; if True, then we should smooth our labels. In PyTorch, this looks like `labels = torch.ones(size) * 0.9`.

The discriminator loss for the fake data is similar. We want `D(fake_images) = 0`, where the fake images are the generator output, `fake_images = G(z)`.

The generator loss will look similar, only with flipped labels. The generator's goal is to get `D(fake_images) = 1`. In this case, the labels are flipped to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!
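In equation form, training with flipped labels implements the non-saturating generator loss from the original GAN paper: rather than minimizing $\log(1 - D(G(z)))$, the generator minimizes

$$-\log D(G(z)),$$

which provides much stronger gradients early in training, when the discriminator can easily reject the generator's samples.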
In [ ]:
# Calculate losses
def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # smooth real labels from 1.0 to 0.9 if smooth=True
    if smooth:
        labels = torch.ones(batch_size) * 0.9
    else:
        labels = torch.ones(batch_size)  # real labels = 1
    criterion = nn.BCEWithLogitsLoss()
    # compare logits to real labels
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)  # fake labels = 0
    criterion = nn.BCEWithLogitsLoss()
    # compare logits to fake labels
    loss = criterion(D_out.squeeze(), labels)
    return loss
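As a quick sanity check with made-up logits: a strongly positive logit means the discriminator is confident an image is real, so it should produce a small `real_loss` and a large `fake_loss`:

In [ ]:
# hypothetical, confident "this is real" logits
logits = torch.tensor([[3.0], [2.5]])
print(real_loss(logits))               # small: D agrees with the real labels
print(real_loss(logits, smooth=True))  # slightly larger, due to the 0.9 targets
print(fake_loss(logits))               # large: D disagrees with the fake labels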
In [ ]:
import torch.optim as optim

# learning rate for optimizers
lr = 0.002

# Create optimizers for the discriminator and generator (Adam is a standard choice here)
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)
Training will involve alternating between training the discriminator and the generator. We'll use our functions `real_loss` and `fake_loss` to help us calculate the discriminator losses in all of the following cases.
As we train, we'll also print out some loss statistics and save some generated "fake" samples.
In [ ]:
import pickle as pkl

# training hyperparams
num_epochs = 40

# keep track of loss and generated, "fake" samples
samples = []
losses = []

print_every = 400

# Get some fixed data for sampling. These are images that are held
# constant throughout training, and allow us to inspect the model's performance
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

# train the network
D.train()
G.train()
for epoch in range(num_epochs):

    for batch_i, (real_images, _) in enumerate(train_loader):

        batch_size = real_images.size(0)

        ## Important rescaling step ##
        real_images = real_images*2 - 1  # rescale input images from [0,1) to [-1, 1)

        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================

        d_optimizer.zero_grad()

        # 1. Train with real images

        # Compute the discriminator losses on real images
        # use smoothed labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)

        # 2. Train with fake images

        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)

        # add up real and fake losses and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================

        g_optimizer.zero_grad()

        # 1. Train with fake images and flipped labels

        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        # using flipped labels!
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake)  # real_loss supplies the flipped (real) labels

        # perform backprop
        g_loss.backward()
        g_optimizer.step()

        # Print some loss stats
        if batch_i % print_every == 0:
            # print discriminator and generator loss
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    ## AFTER EACH EPOCH ##
    # append discriminator loss and generator loss
    losses.append((d_loss.item(), g_loss.item()))

    # generate and save sample, fake images
    G.eval()  # eval mode for generating samples
    samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train()  # back to train mode

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
In [ ]:
fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()
In [ ]:
# helper function for viewing a list of passed in sample images
def view_samples(epoch, samples):
fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
for ax, img in zip(axes.flatten(), samples[epoch]):
img = img.detach()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
In [ ]:
# Load samples from generator, taken while training
with open('train_samples.pkl', 'rb') as f:
samples = pkl.load(f)
These are samples from the final training epoch. You can see that the generator is able to reproduce numbers like 1, 7, 3, and 2. Since this is just a sample, it isn't representative of the full range of images this generator can make.
In [ ]:
# -1 indicates final epoch's samples (the last in the list)
view_samples(-1, samples)
Below I'm showing the generated images as the network was training, sampled every 4 epochs.
In [ ]:
rows = 10 # split epochs into 10 rows, so 40/10 = one row every 4 epochs
cols = 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)
for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
img = img.detach()
ax.imshow(img.reshape((28,28)), cmap='Greys_r')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
It starts out as all noise. Then it learns to make only the center white and the rest black. You can start to see some number-like structures appear out of the noise, like 1s and 9s.
In [ ]:
# randomly generated, new latent vectors
sample_size=16
rand_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
rand_z = torch.from_numpy(rand_z).float()
G.eval() # eval mode
# generated samples
rand_images = G(rand_z)
# 0 indicates the first set of samples in the passed in list
# and we only have one batch of samples, here
view_samples(0, [rand_images])